import torch as pt
import numpy as np
import torchvision as ptv
import sys


def sep(label='', cnt=32):
    """Print *label* framed by ``cnt`` dashes on each side, with no spaces."""
    bar = '-' * cnt
    print(f'{bar}{label}{bar}')


class Resnet34Clf(pt.nn.Module):
    """Image classifier: a frozen, ImageNet-pretrained ResNet-34 backbone
    followed by a single trainable linear head with `n_cls` outputs."""

    def __init__(self, n_cls, **kwargs):
        """Build the frozen backbone and a fresh linear head.

        n_cls: number of output classes for the new head.
        """
        super().__init__(**kwargs)

        backbone = ptv.models.resnet34(pretrained=True)

        # Freeze every backbone weight; only the new head is trained.
        for p in backbone.parameters():
            p.requires_grad = False

        n_feat = backbone.fc.in_features
        # Drop the original fc layer, keeping conv stack + global avg-pool.
        self.resnet = pt.nn.Sequential(*list(backbone.children())[:-1])
        self.fc = pt.nn.Linear(n_feat, n_cls)

    def forward(self, x):
        # Backbone emits (N, C, 1, 1); remove the two trailing singleton
        # dims before the linear head.
        feats = self.resnet(x)
        feats = feats.squeeze(3).squeeze(2)
        return self.fc(feats)


if '__main__' == __name__:

    from torch.utils.data import TensorDataset, DataLoader
    from sklearn.model_selection import train_test_split
    import os
    import cv2 as cv

    # Prefer the first CUDA device when available, else fall back to CPU.
    device = 'cuda:0' if pt.cuda.is_available() else 'cpu'
    print('device', device)
    device = pt.device(device)

    # Two classes: cat (0) and dog (1) — see label construction below.
    model = Resnet34Clf(2).to(device)

    # Smoke test: a dummy batch of four 224x224 RGB images should produce
    # logits of shape (4, 2).
    sep('Test model simply')
    x = np.zeros((4, 3, 224, 224), dtype=np.float32)
    x = pt.Tensor(x).to(device)
    pred = model(x)
    print('pred', pred.size())

    sep('load cat and dog')
    # Hyperparameters.
    BATCH_SIZE = 32
    EPOCH = 8  # number of training epochs
    ALPHA = 1e-3  # learning rate

    # Dataset root; expects 'cat' and 'dog' subdirectories of image files.
    data_path = r'../../../../../large_data/DL1/_many_files/zoo'
    # ResNet's expected input resolution.
    IMG_H = 224
    IMG_W = 224

    def get_pic_data(dir):
        """Load every image in *dir* as float32 scaled to [0, 1].

        Returns an array of shape (n_images, IMG_H, IMG_W, 3) in BGR
        channel order (OpenCV's imread default). Files that OpenCV cannot
        decode are skipped with a warning instead of crashing.
        """
        files = os.listdir(dir)
        x = []
        for file in files:
            path = os.path.join(dir, file)
            img = cv.imread(path, cv.IMREAD_COLOR)
            if img is None:
                # imread returns None for non-image/corrupt files; the old
                # code then crashed inside cv.resize.
                print('skipping unreadable file:', path)
                continue
            # cv.resize takes (width, height) — not (height, width). The
            # constants are equal today, but keep the order correct in case
            # IMG_H and IMG_W ever diverge.
            img = cv.resize(img, (IMG_W, IMG_H))
            img = img.astype(np.float32) / 255.
            x.append(img)
        return np.asarray(x, dtype=np.float32)


    # Load both classes and build label vectors: cat -> 0, dog -> 1.
    x_cat = get_pic_data(os.path.join(data_path, 'cat'))
    print('x_cat:', x_cat.shape)
    n_cat = len(x_cat)
    x_dog = get_pic_data(os.path.join(data_path, 'dog'))
    print('x_dog:', x_dog.shape)
    n_dog = len(x_dog)
    x = np.concatenate([x_cat, x_dog], axis=0)
    # NHWC (OpenCV) -> NCHW (PyTorch convolution layout).
    x = np.transpose(x, [0, 3, 1, 2])
    y_cat = np.full([n_cat], 0, dtype=np.int32)
    print('y_cat:', y_cat.shape)
    y_dog = np.full([n_dog], 1, dtype=np.int32)
    print('y_dog:', y_dog.shape)
    y = np.concatenate([y_cat, y_dog], axis=0)
    print('x:', x.shape)
    print('y:', y.shape)
    # 80% train, then split the remainder evenly into val and test
    # (80/10/10 overall). Fixed random_state keeps splits reproducible.
    x_train, x_val_test, y_train, y_val_test = train_test_split(x, y, train_size=0.8, random_state=1, shuffle=True)
    x_val, x_test, y_val, y_test = train_test_split(x_val_test, y_val_test, train_size=0.5, random_state=1,
                                                    shuffle=True)
    print('x_train', x_train.shape)
    print('x_val', x_val.shape)
    print('x_test', x_test.shape)
    print('y_train', y_train.shape)
    print('y_val', y_val.shape)
    print('y_test', y_test.shape)

    # pt.Tensor produces float32 tensors (labels are cast back to long per
    # batch inside process_data).
    x_train = pt.Tensor(x_train)
    x_val = pt.Tensor(x_val)
    x_test = pt.Tensor(x_test)
    y_train = pt.Tensor(y_train)
    y_val = pt.Tensor(y_val)
    y_test = pt.Tensor(y_test)
    ds_train = TensorDataset(x_train, y_train)
    ds_val = TensorDataset(x_val, y_val)
    ds_test = TensorDataset(x_test, y_test)
    dl_train = DataLoader(ds_train, BATCH_SIZE, shuffle=True)
    dl_val = DataLoader(ds_val, BATCH_SIZE, shuffle=True)
    dl_test = DataLoader(ds_test, BATCH_SIZE, shuffle=True)

    # CrossEntropyLoss expects raw logits (the model has no softmax).
    criterion = pt.nn.CrossEntropyLoss().to(device)
    # optim = pt.optim.Adam(model.parameters(), lr=ALPHA)
    optim = pt.optim.SGD(model.parameters(), lr=ALPHA, momentum=0.9, weight_decay=1e-4)

    def accuracy(y_true, y_pred):
        """Fraction of samples where argmax of y_pred matches y_true.

        y_true: (N,) integer class labels.
        y_pred: (N, n_cls) raw logits (or probabilities).
        Returns a 0-dim float tensor in [0, 1].
        """
        # (The original had a no-op `y_true = y_true` line; removed.)
        pred_cls = y_pred.argmax(dim=1)
        return (pred_cls == y_true).float().mean()


    def process_data(dl, is_train, label):
        """Run one full pass over dataloader *dl*.

        is_train=True: forward/backward + optimizer step per batch.
        is_train=False: forward only, under torch.no_grad().
        Uses module-level globals: model, optim, criterion, device, and
        `epoch` (log messages only).
        Returns (mean_loss, mean_accuracy) averaged over the batches.
        """
        len_dl = len(dl)
        if len_dl == 0:
            # Previously an empty loader hit a NameError on `i` below.
            return 0., 0.
        GROUP = int(np.ceil(len_dl / 10))  # log roughly 10 times per pass
        avg_loss = 0.
        avg_acc = 0.
        for i, (bx, by) in enumerate(dl):
            bx = bx.float().to(device)
            by = by.long().to(device)
            if is_train:
                model.train()
                optim.zero_grad()
                h = model(bx)
                loss = criterion(h, by)
                loss.backward()
                optim.step()
                acc = accuracy(by, h)
                model.eval()
            else:
                model.eval()
                # no_grad skips building the autograd graph during
                # evaluation (less memory, same outputs).
                with pt.no_grad():
                    h = model(bx)
                    loss = criterion(h, by)
                    acc = accuracy(by, h)
            lossv = loss.detach().cpu().numpy()
            accv = acc.detach().cpu().numpy()
            avg_loss += lossv
            avg_acc += accv
            if i % GROUP == 0:
                print(f'{label}: epoch#{epoch + 1}: #{i + 1} loss = {lossv}, acc = {accv}')
        if i % GROUP != 0:
            # Ensure the final batch is always logged.
            print(f'{label}: epoch#{epoch + 1}: #{i + 1} loss = {lossv}, acc = {accv}')
        avg_loss /= i + 1
        avg_acc /= i + 1
        return avg_loss, avg_acc


    # Per-epoch history for the plots below.
    loss_his = []
    acc_his = []
    loss_his_val = []
    acc_his_val = []
    for epoch in range(EPOCH):
        sep(epoch + 1)
        avg_loss, avg_acc = process_data(dl_train, True, 'train')
        avg_loss_val, avg_acc_val = process_data(dl_val, False, 'val')
        loss_his.append(avg_loss)
        loss_his_val.append(avg_loss_val)
        acc_his.append(avg_acc)
        acc_his_val.append(avg_acc_val)
        print(
            f'epoch#{epoch + 1}: loss = {avg_loss} acc = {avg_acc}, loss_val = {avg_loss_val}, acc_val = {avg_acc_val}')

    # Final held-out evaluation. NOTE: process_data's log lines reference
    # the global `epoch`, which here still holds its last loop value.
    sep('Test')
    avg_loss_test, avg_acc_test = process_data(dl_test, False, 'test')
    print(f'Test loss = {avg_loss_test}, acc = {avg_acc_test}')

    import matplotlib.pyplot as plt

    # Side-by-side loss and accuracy curves (train vs validation).
    plt.figure(figsize=[12, 6])
    spr = 1  # subplot rows
    spc = 2  # subplot columns
    spn = 0  # current subplot index

    spn += 1
    plt.subplot(spr, spc, spn)
    plt.plot(loss_his, label='train loss')
    plt.plot(loss_his_val, label='val loss')
    plt.legend()

    spn += 1
    plt.subplot(spr, spc, spn)
    plt.plot(acc_his, label='train acc')
    plt.plot(acc_his_val, label='val acc')
    plt.legend()

    plt.show()

